{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "56987c98-2ca1-41fb-83c1-c3455a8fed46",
   "metadata": {
    "tags": []
   },
   "source": [
    "Generating Simulation Data\n",
    "==========================\n",
    "\n",
    "This notebook contains all the scripts we used to generate simulated data which fed into the figures in the manuscript.\n",
    "\n",
    "Keep in mind that all the data that results from these sweeps has been generated already - if you would like\n",
    "to look at and work with the exact simulation results that we used to generate the figures in the paper, you\n",
    "do not need to rerun these cells (although you certainly can!). Just be careful not to accidentally overwrite\n",
    "the pregenerated data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1cb13c71",
   "metadata": {},
   "outputs": [],
   "source": [
    "from RPI_tools import pytorch_tools, tf_tools\n",
    "import time\n",
    "import numpy as np\n",
    "import torch as t\n",
    "from tqdm import tqdm\n",
    "import os\n",
    "\n",
    "# Edit this line to point to the location of the data folder on your system.\n",
    "data_folder_prefix = 'data/'"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "50592d2c-bda2-4394-b8bc-5b50c151bb6a",
   "metadata": {
    "tags": []
   },
   "source": [
    "Sweep over R\n",
    "------------\n",
    "\n",
    "The following cells generate simulated data, and do the relevant iterative reconstructions, for a sweep\n",
    "over the resolution ratio R at fixed photon fluence. The first cell sets the parameters for the sweep, the\n",
    "second cell runs the simulations, and the third cell runs the relevant iterative reconstructions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c7681ef2-54ae-47d6-b6ba-e2133863e9c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder that receives every file written by the sweep over R\n",
    "prefix = data_folder_prefix + 'Simulated/R_Sweep/'\n",
    "\n",
    "# Create the output folder (and any missing parents) the first time this cell runs.\n",
    "# exist_ok=True keeps the call safe if the folder appears between the check and the creation.\n",
    "if not os.path.exists(prefix):\n",
    "    os.makedirs(prefix, exist_ok=True)\n",
    "    print(\"The new directory is created!\")\n",
    "\n",
    "Rs = [0.25, 0.5, 1, 2] # The set of resolution ratios to study\n",
    "training = 100 # The number of training images to generate (alternative value: 4000)\n",
    "testing = 100 # The number of testing images to generate\n",
    "photon_level = 1e4 # The photon level to study, per pixel in the RPI reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df0c4d31-f5eb-4cea-a0d6-d640762a4b5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Simulate and save one data set (probe, images, diffraction patterns) per resolution ratio.\n",
    "for R in Rs:\n",
    "\n",
    "    # The maximum frequency contained in the probe, in pixels, for the given\n",
    "    # resolution ratio and an object size of 128\n",
    "    probe_maxk = 128//R\n",
    "\n",
    "    # We then calculate the size of the array on which we have to simulate\n",
    "    # the probe to capture all the frequency components.\n",
    "    field_size = int(np.ceil(probe_maxk) + 1) * 2\n",
    "\n",
    "    # This generates an ideal BLR probe with the given parameters\n",
    "    probe = tf_tools.generate_blr_probe([field_size,field_size], probe_maxk)\n",
    "\n",
    "    # We generate training and testing data - diffraction patterns and ground truth images.\n",
    "    expanded_probe, (tr_patterns, tr_images), (test_patterns, test_images) = tf_tools.generate_imagenet_phase_data(\n",
    "        probe, n_train=training, n_test = testing, n_photons_per_pix = photon_level, data_folder_prefix=data_folder_prefix)\n",
    "\n",
    "    print('Data Generated')\n",
    "\n",
    "    num_tr = tr_patterns.shape[0]\n",
    "    num_test = test_patterns.shape[0]\n",
    "    num_rows = tr_patterns.shape[1]\n",
    "    num_cols = tr_patterns.shape[2]\n",
    "\n",
    "    print('Training Pattern Dimensions:', num_tr, 'by', num_rows, 'by', num_cols)\n",
    "    print('Testing Pattern Dimensions:', num_test, 'by', num_rows, 'by', num_cols)\n",
    "\n",
    "    # The images and the probe are saved once per resolution ratio. Every file name\n",
    "    # carries R, so one pass of the sweep never overwrites the output of another.\n",
    "    np.save(prefix + 'test_images-R-%0.2f.npy' % R, test_images)\n",
    "    np.save(prefix + 'tr_images-R-%0.2f.npy' % R, tr_images)\n",
    "    np.save(prefix + 'probe-R-%0.2f.npy' % R, expanded_probe)\n",
    "\n",
    "    # The patterns are also affected by the photon level, so it is recorded in their file names.\n",
    "    np.save(prefix + 'test_patterns-R-%0.2f-phperpix-%d.npy' % (R, photon_level), test_patterns)\n",
    "    np.save(prefix + 'tr_patterns-R-%0.2f-phperpix-%d.npy' % (R, photon_level), tr_patterns)\n",
    "\n",
    "    print('Simulated Data for R=%0.2f Saved' % R)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "63d13c17-495e-4b4c-a6e1-277ec20f16dd",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Iteration counts at which full iterative reconstructions of the testing data are saved\n",
    "reconstruction_checkpoints = [1, 5, 10, 50, 100, 200, 300, 500, 1000]\n",
    "\n",
    "lr = 0.5 # set the learning rate for iterative reconstruction\n",
    "\n",
    "for R in Rs:\n",
    "\n",
    "    print('Working on Resolution Ratio', R)\n",
    "    # Reload the probe and patterns written by the data generation cell for this R\n",
    "    expanded_probe = t.from_numpy(np.load(prefix + 'probe-R-%0.2f.npy' % R))\n",
    "\n",
    "    print('Generating Approximants from Training Data')\n",
    "    tr_patterns = t.from_numpy(np.load(prefix + 'tr_patterns-R-%0.2f-phperpix-%d.npy' % (R, photon_level)))\n",
    "\n",
    "    tr_approximants = []\n",
    "    for pattern in tqdm(tr_patterns):\n",
    "        # This performs a single iteration of the reconstruction algorithm with steepest gradient descent\n",
    "        # for training the deep k-learning framework\n",
    "        approximant, _ = pytorch_tools.reconstruct(pattern, expanded_probe, resolution=256, lr=1, iterations=1, \n",
    "                                                   loss_func=pytorch_tools.amplitude_mse, GPU=True)\n",
    "        # Only the phase of the reconstructed object is kept\n",
    "        tr_approximants.append(np.angle(approximant.numpy()))\n",
    "\n",
    "    tr_approximants = np.array(tr_approximants)\n",
    "    # The trailing (1, 1) records iterations=1 and lr=1 in the file name, matching the call above\n",
    "    np.save(prefix + 'tr-reconstruction-R-%0.2f-phperpix-%d-iters-%d-lr-%0.2f.npy' \n",
    "            % (R, photon_level, 1, 1), tr_approximants)\n",
    "\n",
    "    print('Generating Approximants from Testing Data')\n",
    "    test_patterns = t.from_numpy(np.load(prefix + 'test_patterns-R-%0.2f-phperpix-%d.npy' % (R, photon_level)))\n",
    "\n",
    "    test_approximants = []\n",
    "    for pattern in tqdm(test_patterns):\n",
    "        # This performs a single iteration of the reconstruction algorithm with steepest gradient descent\n",
    "        # for testing the deep k-learning framework\n",
    "        approximant, _ = pytorch_tools.reconstruct(pattern, expanded_probe, resolution=256, lr=1, iterations=1, \n",
    "                                                   loss_func=pytorch_tools.amplitude_mse, GPU=True)\n",
    "        test_approximants.append(np.angle(approximant.numpy()))\n",
    "\n",
    "    test_approximants = np.array(test_approximants)\n",
    "    np.save(prefix + 'test-reconstruction-R-%0.2f-phperpix-%d-iters-%d-lr-%0.2f.npy' \n",
    "            % (R, photon_level, 1, 1), test_approximants)\n",
    "\n",
    "    print('Performing Reconstructions on Testing Data')\n",
    "    # Store runtime information for each iteration step\n",
    "    runtimes = np.zeros([2,len(reconstruction_checkpoints)]) \n",
    "    for i, iters in enumerate(reconstruction_checkpoints):\n",
    "\n",
    "        print('Doing Reconstructions with %d iterations' % iters)\n",
    "\n",
    "        test_results = []\n",
    "        # perf_counter is a monotonic clock, so the measured duration cannot be\n",
    "        # distorted by system clock adjustments during a long run\n",
    "        start = time.perf_counter()\n",
    "        for pattern in tqdm(test_patterns):\n",
    "            # This performs <iters> iterations of the reconstruction algorithms\n",
    "            result, _ = pytorch_tools.reconstruct(pattern, expanded_probe, resolution=256, \n",
    "                                                  lr=lr, iterations=iters, loss_func=pytorch_tools.amplitude_mse, GPU=True)\n",
    "            test_results.append(np.angle(result.numpy()))\n",
    "\n",
    "        # The format is [# iterations, runtime in seconds for the whole testing set] for each checkpoint\n",
    "        runtimes[:,i] = [iters, time.perf_counter() - start]\n",
    "\n",
    "        test_results = np.array(test_results)\n",
    "        np.save(prefix + 'test-reconstruction-R-%0.2f-phperpix-%d-iters-%d-lr-%0.2f.npy' % (R, photon_level, iters, lr), test_results)\n",
    "\n",
    "    np.save(prefix + 'runtimes-R-%0.2f-phperpix-%d-lr-%0.2f.npy' % (R, photon_level, lr), runtimes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "048e0d4e-da0a-41c2-9e24-1020ca1a4d63",
   "metadata": {},
   "source": [
    "Sweep over Photon Shot Noise at Fixed R\n",
    "---------------------------------------\n",
    "\n",
    "The following cells generate simulated data, and do the relevant iterative reconstructions, for a sweep\n",
    "over the photon fluence at a fixed resolution ratio R. The first cell sets the parameters for the sweep, the\n",
    "second cell runs the simulations, and the third cell runs the relevant iterative reconstructions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f032fd9-70c3-4dc8-863a-e7a426e84c2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Folder that receives every file written by the photon-level sweep at fixed R\n",
    "prefix = data_folder_prefix + 'Simulated/Fixed_R_Noise_Sweep/'\n",
    "\n",
    "# Create the output folder (and any missing parents) the first time this cell runs.\n",
    "# exist_ok=True keeps the call safe if the folder appears between the check and the creation.\n",
    "if not os.path.exists(prefix):\n",
    "    os.makedirs(prefix, exist_ok=True)\n",
    "    print(\"The new directory is created!\")\n",
    "\n",
    "R = 0.5 # The resolution ratio to study\n",
    "training = 100 # The number of training images to generate (alternative value: 4000)\n",
    "testing = 100 # The number of testing images to generate\n",
    "photon_levels = [1e-2, 1e-1, 1, 1e1, 1e2, 1e3] # The set of photon levels to study, per pixel in the RPI reconstruction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "09446b32-9230-4721-b3e2-8d2050809561",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Simulate and save one set of diffraction patterns per photon level, at the fixed ratio R.\n",
    "for photon_level in photon_levels:\n",
    "\n",
    "    # Highest frequency present in the probe, in pixels, for the chosen\n",
    "    # resolution ratio and an object size of 128\n",
    "    probe_maxk = 128//R\n",
    "\n",
    "    # Side length of the square array on which the probe is simulated, chosen\n",
    "    # so that all of its frequency components are captured\n",
    "    field_size = int(np.ceil(probe_maxk) + 1) * 2\n",
    "    probe_shape = [field_size, field_size]\n",
    "\n",
    "    # Build an ideal BLR probe on that array\n",
    "    probe = tf_tools.generate_blr_probe(probe_shape, probe_maxk)\n",
    "\n",
    "    # Simulate the training and testing sets: diffraction patterns plus ground truth images\n",
    "    expanded_probe, (tr_patterns, tr_images), (test_patterns, test_images) = tf_tools.generate_imagenet_phase_data(\n",
    "        probe,\n",
    "        n_train=training,\n",
    "        n_test=testing,\n",
    "        n_photons_per_pix=photon_level,\n",
    "        data_folder_prefix=data_folder_prefix)\n",
    "\n",
    "    print('Data Generated')\n",
    "\n",
    "    num_tr, num_test = tr_patterns.shape[0], test_patterns.shape[0]\n",
    "    num_rows, num_cols = tr_patterns.shape[1], tr_patterns.shape[2]\n",
    "\n",
    "    print('Training Pattern Dimensions:', num_tr, 'by', num_rows, 'by', num_cols)\n",
    "    print('Testing Pattern Dimensions:', num_test, 'by', num_rows, 'by', num_cols)\n",
    "\n",
    "    # These files are named by R only, so each photon level rewrites them.\n",
    "    # NOTE(review): this is only harmless if the images and the probe come out identical\n",
    "    # on every call for the same R - TODO confirm against tf_tools.\n",
    "    files_per_R = {\n",
    "        'test_images-R-%0.2f.npy': test_images,\n",
    "        'tr_images-R-%0.2f.npy': tr_images,\n",
    "        'probe-R-%0.2f.npy': expanded_probe,\n",
    "    }\n",
    "    for name_template, array in files_per_R.items():\n",
    "        np.save(prefix + name_template % R, array)\n",
    "\n",
    "    # The patterns are affected by the photon level, so they go to separate files per level.\n",
    "    files_per_level = {\n",
    "        'test_patterns-R-%0.2f-phperpix-%0.2f.npy': test_patterns,\n",
    "        'tr_patterns-R-%0.2f-phperpix-%0.2f.npy': tr_patterns,\n",
    "    }\n",
    "    for name_template, array in files_per_level.items():\n",
    "        np.save(prefix + name_template % (R, photon_level), array)\n",
    "\n",
    "    print('Simulated Data for Photon Level %0.2f Saved' % photon_level)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c50d7661",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Iteration counts at which full iterative reconstructions of the testing data are saved\n",
    "reconstruction_checkpoints = [1, 5, 10, 50, 100, 200, 300, 500, 1000]\n",
    "lr = 0.5 # The learning rate for the multi-iteration reconstructions\n",
    "\n",
    "for photon_level in photon_levels:\n",
    "\n",
    "    print('Working on Photon Level', photon_level)\n",
    "    # Reload the probe and patterns written by the data generation cell\n",
    "    expanded_probe = t.from_numpy(np.load(prefix + 'probe-R-%0.2f.npy' % R))\n",
    "\n",
    "    print('Generating Approximants from Training Data')\n",
    "    tr_patterns = t.from_numpy(np.load(prefix + 'tr_patterns-R-%0.2f-phperpix-%0.2f.npy' % (R, photon_level)))\n",
    "\n",
    "    tr_approximants = []\n",
    "    for pattern in tqdm(tr_patterns):\n",
    "        # This performs a single iteration of the reconstruction algorithm\n",
    "        approximant, _ = pytorch_tools.reconstruct(pattern, expanded_probe, 256, \n",
    "                                                   lr=1, iterations=1, loss_func=pytorch_tools.amplitude_mse, GPU=True)\n",
    "        # Only the phase of the reconstructed object is kept\n",
    "        tr_approximants.append(np.angle(approximant.numpy()))\n",
    "\n",
    "    tr_approximants = np.array(tr_approximants)\n",
    "    # The trailing (1, 1) records iterations=1 and lr=1 in the file name, matching the call above\n",
    "    np.save(prefix + 'tr-reconstruction-R-%0.2f-phperpix-%0.2f-iters-%d-lr-%0.2f.npy' % (R, photon_level, 1, 1), tr_approximants)\n",
    "\n",
    "    print('Generating Approximants from Testing Data')\n",
    "    test_patterns = t.from_numpy(np.load(prefix + 'test_patterns-R-%0.2f-phperpix-%0.2f.npy' % (R, photon_level)))\n",
    "    test_approximants = []\n",
    "    for pattern in tqdm(test_patterns):\n",
    "        # This performs a single iteration of the reconstruction algorithm\n",
    "        approximant, _ = pytorch_tools.reconstruct(pattern, expanded_probe, 256, \n",
    "                                                   lr=1, iterations=1, loss_func=pytorch_tools.amplitude_mse, GPU=True)\n",
    "        test_approximants.append(np.angle(approximant.numpy()))\n",
    "\n",
    "    test_approximants = np.array(test_approximants)\n",
    "    np.save(prefix + 'test-reconstruction-R-%0.2f-phperpix-%0.2f-iters-%d-lr-%0.2f.npy' % (R, photon_level, 1, 1), test_approximants)\n",
    "\n",
    "    print('Performing Reconstructions on Testing Data')\n",
    "    runtimes = np.zeros([2,len(reconstruction_checkpoints)]) # This will store runtime information for each iteration step\n",
    "\n",
    "    for i, iters in enumerate(reconstruction_checkpoints):\n",
    "\n",
    "        print('Doing Reconstructions with %d iterations' % iters)\n",
    "\n",
    "        test_results = []\n",
    "        # perf_counter is a monotonic clock, so the measured duration cannot be\n",
    "        # distorted by system clock adjustments during a long run\n",
    "        start = time.perf_counter()\n",
    "        for pattern in tqdm(test_patterns):\n",
    "            # This performs <iters> iterations of the reconstruction algorithms\n",
    "            result, _ = pytorch_tools.reconstruct(pattern, expanded_probe, 256, lr=lr, \n",
    "                                                  iterations=iters, loss_func=pytorch_tools.amplitude_mse, GPU=True)\n",
    "            test_results.append(np.angle(result.numpy()))\n",
    "\n",
    "        # The format is [# iterations, runtime in seconds for the whole testing set] for each checkpoint\n",
    "        runtimes[:,i] = [iters, time.perf_counter() - start]\n",
    "\n",
    "        test_results = np.array(test_results)\n",
    "        np.save(prefix + 'test-reconstruction-R-%0.2f-phperpix-%0.2f-iters-%d-lr-%0.2f.npy' % (R, photon_level, iters, lr), test_results)\n",
    "\n",
    "    np.save(prefix + 'runtimes-R-%0.2f-phperpix-%0.2f-lr-%0.2f.npy' % (R, photon_level, lr), runtimes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11d03f1b-9234-4949-9cd2-660eada9653a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
